import logging,os,wandb, time
import collections

import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import DataLoader

import torch.cuda as cuda
import torch.distributed as distributed
import torch.multiprocessing as multiprocessing
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.utils.data.distributed import DistributedSampler as Sampler

from ..dataset.dataset import *
from ..dataset.utils import *
from ..models.classifier import *
from ..models.kkt_parameter import *
from ..models.generator import *
from .utils import *
from .losses import *
from .sampler import *

from ..utils import *

# Module-level logger shared by every function in this file. Its name is
# whatever `module_structure(__file__)` returns (helper comes from one of the
# star imports above -- presumably `..utils`; TODO confirm).
logger = logging.getLogger(module_structure(__file__))


def create_guider(cfg, trainset, *args, **kwargs):
    """Build the pretrained classifier ("guider") and load its weights.

    Args:
        cfg: Run configuration. Reads ``cfg.device``, ``cfg.Classifier.Name``,
            ``cfg.Classifier.load_path``,
            ``cfg.Generator.normalization_match_coe``, ``cfg.ddp`` and, when
            DDP is enabled, ``cfg.local_rank``.
        trainset: Dataset object providing ``num_classes`` plus either
            ``input_shape`` (LinearNet) or ``num_channels`` (ConvNet/ResNet).
        *args, **kwargs: Forwarded unchanged to the classifier constructor.

    Returns:
        Tuple ``(guider, normalization_stat)``. ``normalization_stat`` is the
        result of ``get_normalization_layer_stats(guider)`` when
        ``cfg.Generator.normalization_match_coe > 0`` and ``None`` otherwise.

    Raises:
        NotImplementedError: If ``cfg.Classifier.Name`` is not one of
            ``"LinearNet"``, ``"ConvNet"`` or ``"ResNet"``.
    """
    device = cfg.device

    num_classes = trainset.num_classes

    if cfg.Classifier.Name == "LinearNet":
        guider = LinearNet(cfg, *args, num_classes=num_classes, input_shape=trainset.input_shape, **kwargs)
    elif cfg.Classifier.Name == "ConvNet":
        guider = ConvNet(cfg, *args, num_channels=trainset.num_channels, num_classes=num_classes, **kwargs)
    elif cfg.Classifier.Name == "ResNet":
        guider = ResNet(cfg, *args, num_channels=trainset.num_channels, num_classes=num_classes, **kwargs)
    else:
        raise NotImplementedError(f'cfg.Classifier.Name: {cfg.Classifier.Name}')
    logger.info(f'classifier {cfg.Classifier.Name} is built')

    # Deserialize onto CPU: load_state_dict copies the values into the model's
    # own parameters, so the checkpoint does not need to be restored onto the
    # device it was saved from (which may not exist on this machine/rank).
    guider.load_state_dict(torch.load(cfg.Classifier.load_path, map_location="cpu"))
    logger.info('classifier is loaded')
    guider = guider.to(device)

    # Always bind the second return value; previously it was left undefined
    # whenever normalization matching was disabled, crashing at `return`.
    normalization_stat = None
    if cfg.Generator.normalization_match_coe > 0.0:
        # Collected before the DDP wrap so the helper sees the bare module.
        normalization_stat = get_normalization_layer_stats(guider)

    if cfg.ddp:
        guider = DDP(guider, device_ids=[cfg.local_rank])
        logger.info('classifier is on DDP mode')

    return guider, normalization_stat

def create_Lambda(cfg, guider, *args, **kwargs):
    """Return the Lambda object selected by ``cfg.Lambda``.

    * ``cfg.Lambda.use`` falsy -> ``None``.
    * ``cfg.Lambda.path`` truthy -> whatever ``torch.load`` reads from it.
    * otherwise -> a new ``LambdaModule`` built from ``guider`` and a random
      tensor of shape ``(cfg.Lambda.num_samples, cfg.Classifier.input_shape)``
      with entries uniform in ``[-25, 25)``.

    ``*args`` and ``**kwargs`` are accepted but not used.
    """
    # Not used below; the attribute is still read exactly as before.
    device = cfg.device

    if not cfg.Lambda.use:
        logger.info('Lambda is not used')
        return None

    if cfg.Lambda.path:
        loaded = torch.load(cfg.Lambda.path)
        logger.info('Lambda is loaded')
        return loaded

    # rand() is uniform on [0, 1): shift to [-0.5, 0.5) and scale by 50.
    random_inputs = (torch.rand(cfg.Lambda.num_samples, cfg.Classifier.input_shape) - 0.5) * 50
    built = LambdaModule(guider, random_inputs)
    logger.info('Lambda is')
    logger.info(f'{built.Lambda_dict}')
    return built

def create_alpha(cfg, *args, **kwargs):
    """Build the alpha module and its Adam optimizer.

    Returns ``(None, None)`` when ``cfg.alpha.use`` is falsy. Otherwise an
    ``AlphaModule(cfg, *args, **kwargs)`` is created and moved to
    ``cfg.device``. With ``cfg.ddp`` set it is wrapped in DDP on
    ``cfg.local_rank`` and optimized through ZeRO-wrapped Adam; without it a
    plain ``optim.Adam`` is used. Learning rate and weight decay come from
    ``cfg.alpha.lr`` and ``cfg.alpha.weight_decay``.
    """
    device = cfg.device

    if not cfg.alpha.use:
        logger.info('alpha module not used')
        return None, None

    module = AlphaModule(cfg, *args, **kwargs)
    logger.info('alpha module is built')
    module = module.to(device)

    if cfg.ddp:
        module = DDP(module, device_ids=[cfg.local_rank])
        logger.info('alpha is on DDP mode')
        optimizer = ZeRO(
            module.parameters(),
            optim.Adam,
            lr=cfg.alpha.lr,
            weight_decay=cfg.alpha.weight_decay,
        )
    else:
        optimizer = optim.Adam(
            module.parameters(),
            lr=cfg.alpha.lr,
            weight_decay=cfg.alpha.weight_decay,
        )

    return module, optimizer

def create_muer(cfg, trainset, *args, **kwargs):
    """Build the "muer" network and its Adam optimizer.

    Returns ``(None, None)`` when ``cfg.muer.use`` is falsy. Otherwise the
    architecture is picked by ``cfg.muer.Name`` (``"linear_xy"`` or
    ``"conv_xy"``), constructed with ``num_classes=trainset.num_classes`` and
    moved to ``cfg.device``. With ``cfg.ddp`` set the module is wrapped in DDP
    on ``cfg.local_rank`` and optimized through ZeRO-wrapped Adam; without it
    a plain ``optim.Adam`` is used.

    Raises:
        NotImplementedError: For any other ``cfg.muer.Name``.
    """
    device = cfg.device

    if not cfg.muer.use:
        logger.info('muer module not used')
        return None, None

    if cfg.muer.Name == "linear_xy":
        module = Muer_Linear_InputXY(cfg, *args, num_classes=trainset.num_classes, **kwargs)
    elif cfg.muer.Name == "conv_xy":
        module = Muer_Conv_InputXY(cfg, *args, num_classes=trainset.num_classes, **kwargs)
    else:
        raise NotImplementedError(f'cfg.muer.Name: {cfg.muer.Name}')
    logger.info('muer is built')
    module = module.to(device)

    if cfg.ddp:
        module = DDP(module, device_ids=[cfg.local_rank])
        logger.info('muer is on DDP mode')
        optimizer = ZeRO(
            module.parameters(),
            optim.Adam,
            lr=cfg.muer.lr,
            weight_decay=cfg.muer.weight_decay,
        )
    else:
        optimizer = optim.Adam(
            module.parameters(),
            lr=cfg.muer.lr,
            weight_decay=cfg.muer.weight_decay,
        )

    return module, optimizer

def create_sampler(cfg, trainset, *args, **kwargs):
    """Construct an ``InputSampler`` for ``trainset.num_classes`` classes.

    ``cfg`` and any extra positional/keyword arguments are forwarded to the
    ``InputSampler`` constructor unchanged.
    """
    built_sampler = InputSampler(
        cfg,
        *args,
        num_classes=trainset.num_classes,
        **kwargs,
    )
    logger.info('input sampler is built')
    return built_sampler

def init_running_loss():
    """Return a fresh, zeroed accumulator for the running losses.

    The result maps each loss group (``"kkt"``, ``"ensemble"``, ``"regu"``,
    ``"match"``) to a dict of named float sums, all starting at ``0.0``.
    A new set of dicts is created on every call.
    """
    return {
        "kkt": {"loss_kkt": 0.0, "lagrange_loss": 0.0, "duality_loss": 0.0},
        "ensemble": {"loss": 0.0},
        "regu": {"loss": 0.0},
        "match": {"loss": 0.0},
    }
    

def _save_cpu_state_dict(module, path):
    """Save ``module``'s state dict to ``path`` with all tensors copied to CPU.

    A DDP wrapper is unwrapped first so the saved keys carry no ``module.``
    prefix. Only copies are moved to CPU; the live module stays on its device
    and keeps its DDP wrapper, so training can continue afterwards.
    """
    if isinstance(module, DDP):
        module = module.module
    state = module.state_dict()
    cpu_state = collections.OrderedDict(
        (key, value.detach().cpu() if torch.is_tensor(value) else value)
        for key, value in state.items()
    )
    # state_dict() attaches versioning info as `_metadata`; carry it over so
    # the saved file matches what `state_dict()` itself would have produced.
    metadata = getattr(state, "_metadata", None)
    if metadata is not None:
        cpu_state._metadata = metadata
    torch.save(cpu_state, path)


def train(cfg, *args, **kwargs):
    """Train the generator against a pretrained classifier using the KKT loss.

    The function loads the dataset and the pretrained guider, builds the
    generator (plus the optional Lambda, alpha and muer components), then runs
    ``cfg.Generator.epoches`` optimization steps. Each step samples one batch
    from the input sampler, generates ``x``, and minimizes the KKT loss plus a
    weighted total-variation regularizer. Every
    ``cfg.Generator.evaluation_interval`` steps the averaged losses are
    logged; every ``cfg.Generator.image_save_interval`` steps (checked only on
    logging steps) visualizations and optional checkpoints are written.

    Args:
        cfg: Run configuration object (device, seed, ddp/wandb flags, and the
            ``Data``, ``Classifier``, ``Generator``, ``Lambda``, ``alpha``,
            ``muer`` and ``pathes`` sections read below).
        *args, **kwargs: Forwarded to the component constructors.

    Returns:
        ``{"similarity": <last similarity score or None>,
        "min similarity": <smallest similarity score seen, initially 1.0>}``.
        ``"similarity"`` is ``None`` when no similarity visualization ran
        (e.g. 2D data, or no image-saving step was reached).

    Raises:
        NotImplementedError: For an unknown ``cfg.Generator.Name`` or
            ``cfg.Generator.regulariztion_name``.
    """
    randomness_control(cfg.seed)
    time_dict = collections.OrderedDict()
    time_dict["total_time"] = [time.time()]
    device = cfg.device

    time_dict["data"] = [time.time()]
    trainset, _testset = load_dataset(cfg.Data.path)
    logger.info('Dataset is built')
    training_input = extract_input(trainset)
    training_input = training_input.to(device)
    training_label = extract_output(trainset)
    training_label = training_label.to(device)
    logger.info('Training input is extracted')
    time_dict["data"].append(time.time())

    time_dict["guider"] = [time.time()]
    # load pretrained classifier
    guider, normalization_stat = create_guider(cfg, trainset, *args, **kwargs)
    time_dict["guider"].append(time.time())

    # calculate Lambda for the classifier
    Lambda = create_Lambda(cfg, guider, *args, **kwargs)

    time_dict["generator"] = [time.time()]
    # create generator
    if cfg.Generator.Name == "DCG":
        net = DCG(cfg, *args, num_classes=trainset.num_classes, num_channels=trainset.num_channels, **kwargs)
    elif cfg.Generator.Name == "LinearG":
        net = LinearG(cfg, *args, num_classes=trainset.num_classes, input_shape=trainset.input_shape, **kwargs)
    elif cfg.Generator.Name == "IdentityG":
        net = IdentityG(cfg, *args, input_shape=trainset.input_shape, **kwargs)
    else:
        raise NotImplementedError(f'cfg.Generator.Name: {cfg.Generator.Name}')
    logger.info(f'generator {cfg.Generator.Name} is built')
    net = net.to(device)
    if cfg.ddp:
        net = DDP(net, device_ids=[cfg.local_rank])
        logger.info('generator is on DDP mode')
        net_optimizer = ZeRO(net.parameters(), optim.Adam, lr=cfg.Generator.lr, weight_decay=cfg.Generator.weight_decay)
    else:
        net_optimizer = optim.Adam(net.parameters(), lr=cfg.Generator.lr, weight_decay=cfg.Generator.weight_decay)
    time_dict["generator"].append(time.time())

    # create alpha
    alpha, alpha_optimizer = create_alpha(cfg, *args, **kwargs)

    # create muer
    muer, muer_optimizer = create_muer(cfg, trainset, *args, **kwargs)

    time_dict["loss"] = [time.time()]
    # define losses
    criterion_kkt = KKT_loss(cfg, *args, guider=guider, Lambda=Lambda, num_samples=trainset.num_samples, **kwargs)
    time_dict["loss"].append(time.time())

    time_dict["sampler"] = [time.time()]
    # random sampler
    input_sampler = create_sampler(cfg, trainset, *args, **kwargs)
    time_dict["sampler"].append(time.time())

    # create loss log
    running_losses = init_running_loss()

    time_dict["iter"] = {
        "iter_time": [],
        "create_sample": [],
        "create_x": [],
        "classification": [],
        "loss_kkt": [],
        "loss_regu": [],
        "backward": [],
        "log": [],
        "save_img": [],
    }
    min_similarity = 1.0
    # Stays None unless a similarity visualization actually runs, so the
    # final `return` never touches an unbound name.
    similarity_score = None

    # training
    logger.info("Training loop starts")
    for epoch in range(cfg.Generator.epoches):  # one sampled batch per "epoch"
        time_dict["iter"]["iter_time"].append(time.time())
        if cfg.ddp:
            distributed.barrier()

        time_dict["iter"]["create_sample"].append(time.time())
        noise, y = input_sampler(cfg.Generator.batch_size)
        noise, y = noise.to(device), y.to(device)
        time_dict["iter"]["create_sample"].append(time.time())

        # zero the parameter gradients
        net_optimizer.zero_grad()
        if muer_optimizer is not None:
            muer_optimizer.zero_grad()
        if alpha_optimizer is not None:
            alpha_optimizer.zero_grad()

        time_dict["iter"]["create_x"].append(time.time())
        x = net(noise, y)
        time_dict["iter"]["create_x"].append(time.time())

        guider.eval()
        time_dict["iter"]["classification"].append(time.time())
        pred = guider(x)
        time_dict["iter"]["classification"].append(time.time())

        if muer is not None:
            # With double labels only the first label column is fed to muer.
            y_mu = y[:, 0] if cfg.Generator.double_label else y
            mus = muer(x, y_mu)
        else:
            mus = None

        alpha_value = alpha() if alpha is not None else None

        time_dict["iter"]["loss_kkt"].append(time.time())
        loss_kkt, lagrange_loss, duality_loss = criterion_kkt(x, y, mus, pred, alpha_value, guider)
        loss = loss_kkt
        time_dict["iter"]["loss_kkt"].append(time.time())

        time_dict["iter"]["loss_regu"].append(time.time())
        if cfg.Generator.regulariztion_name == "tv":
            regularization_loss = total_variation_loss(x, cfg.Generator.tv_power)
        elif cfg.Generator.regulariztion_name == "mul_tv":
            regularization_loss = Multiply_total_variation_loss(x, cfg.Generator.tv_power)
        else:
            # Previously fell through and crashed below with a NameError.
            raise NotImplementedError(
                f'cfg.Generator.regulariztion_name: {cfg.Generator.regulariztion_name}'
            )
        loss = loss + regularization_loss * cfg.Generator.regularization_coe
        time_dict["iter"]["loss_regu"].append(time.time())

        time_dict["iter"]["backward"].append(time.time())
        loss.backward()
        net_optimizer.step()

        if muer_optimizer is not None:
            muer_optimizer.step()
        if alpha_optimizer is not None:
            alpha_optimizer.step()
        time_dict["iter"]["backward"].append(time.time())
        time_dict["iter"]["iter_time"].append(time.time())

        # accumulate statistics for the next log line
        running_losses['kkt']["loss_kkt"] += loss_kkt.item()
        running_losses['kkt']["lagrange_loss"] += lagrange_loss.item()
        running_losses['kkt']["duality_loss"] += duality_loss.item()

        running_losses["regu"]["loss"] += regularization_loss.item()

        # log on the last step of every evaluation interval
        if epoch % cfg.Generator.evaluation_interval == (cfg.Generator.evaluation_interval - 1):
            is_image_step = (
                epoch % cfg.Generator.image_save_interval == (cfg.Generator.image_save_interval - 1)
            )

            time_dict["iter"]["log"].append(time.time())
            output_string = f"[epoch: {epoch + 1:6d}/{cfg.Generator.epoches}]"
            output_string = output_string + " |KKT loss:| "
            for loss_name, loss_sum in running_losses['kkt'].items():
                output_string = output_string + f' {loss_name}: {loss_sum/cfg.Generator.evaluation_interval:5.3f} '

            output_string = output_string + " |regularziation loss:| "
            for loss_name, loss_sum in running_losses["regu"].items():
                output_string = output_string + f' {loss_name}: {loss_sum/cfg.Generator.evaluation_interval:5.3f} '

            logger.info(output_string)
            time_dict["iter"]["log"].append(time.time())

            if is_image_step:
                time_dict["iter"]["save_img"].append(time.time())
                if cfg.Data.Name.startswith("MNIST") or cfg.Data.Name.startswith("Cifar") or cfg.Data.Name.startswith("celeba"):
                    similar_visualization_path = os.path.join(cfg.pathes.img_path, f"similar_{epoch}.jpg")
                    similarity_score, similar_visualization_grid = similar_visualization(x, training_input, similar_visualization_path)
                    if min_similarity > similarity_score:
                        min_similarity = similarity_score
                    logger.info(f"\t similarity score: {similarity_score:10.4f}")

                    random_visualization_path = os.path.join(cfg.pathes.img_path, f"random_{epoch}.jpg")
                    random_visualization_grid = visualization(x, random_visualization_path)

                elif cfg.Data.Name.startswith("2D"):
                    visualization_path = os.path.join(cfg.pathes.img_path, f"{epoch}.jpg")
                    # The original referenced undefined `total_x`/`total_y`;
                    # the current batch is the only generated data available.
                    plot_y = y[:, 0] if cfg.Generator.double_label else y
                    matplotlib_fig, _ = visualization_2D(x, plot_y, training_input, training_label, visualization_path)

                if cfg.Generator.save_along:
                    # Save a CPU copy only; `net` keeps its device and its DDP
                    # wrapper so later steps and the final save still work.
                    # NOTE(review): under DDP every rank writes this same
                    # file -- confirm whether only one rank should save.
                    _save_cpu_state_dict(net, cfg.Generator.path + f".{epoch}")
                    logger.info(f'Generator is saved to {cfg.Generator.path}.{epoch}')
                time_dict["iter"]["save_img"].append(time.time())

            if cfg.wandb:
                wandb_log_dict = {"kkt": {}, "regularization": {}}

                for loss_name, loss_sum in running_losses['kkt'].items():
                    wandb_log_dict["kkt"][loss_name] = loss_sum/cfg.Generator.evaluation_interval

                for loss_name, loss_sum in running_losses["regu"].items():
                    wandb_log_dict["regularization"][loss_name] = loss_sum/cfg.Generator.evaluation_interval

                if is_image_step:
                    if cfg.Data.Name.startswith("MNIST") or cfg.Data.Name.startswith("Cifar") or cfg.Data.Name.startswith("celeba"):
                        wandb_log_dict['similarity_score'] = similarity_score
                        wandb_log_dict['similar visualization'] = wandb.Image(image_from_torch_to_numpy(similar_visualization_grid), caption="")

                        wandb_log_dict['random_visualization'] = wandb.Image(image_from_torch_to_numpy(random_visualization_grid), caption="")

                    elif cfg.Data.Name.startswith("2D"):
                        wandb_log_dict['visualization'] = wandb.Image(matplotlib_fig, caption="")

                wandb.log(wandb_log_dict, int(epoch + 1))

            running_losses = init_running_loss()

    logger.info('Finished Training')

    # Final checkpoints. The helper unwraps DDP itself, and the optional
    # modules are skipped when disabled (the DDP path used to dereference
    # `None.module`; the non-DDP path used undefined `*_index` suffixes).
    _save_cpu_state_dict(net, cfg.Generator.path)
    logger.info(f'Generator is saved to {cfg.Generator.path}')
    if muer is not None:
        _save_cpu_state_dict(muer, cfg.muer.path)
        logger.info(f'muer is saved to {cfg.muer.path}')
    if alpha is not None:
        _save_cpu_state_dict(alpha, cfg.alpha.path)
        logger.info(f'alpha is saved to {cfg.alpha.path}')

    time_dict["total_time"].append(time.time())

    time_stat = output_time("", 0, time_dict)
    print(time_stat)
    logger.info("\n" + time_stat)
    return {"similarity": similarity_score, "min similarity": min_similarity}